# -*- coding: utf-8 -*-
"""
Created on Thu Jan 18 16:46:07 2019

@author: william

Email: hua_yan_tsn@163.com
"""
from sklearn.svm import SVC
from sklearn.externals import joblib
class LinearSVM:
    def __init__(self):
        self.model = None # 预测结果的模型
    def predict(self, X):
        """
        :param X: X样本的vector| tensor
        :return: 预测的结果vector| tensor
        """
        return self.model.predict(X)
    def train(self, X, Y):
        """
        :param X: 需要拟合的X
        :param Y: 需要拟合的Y
        :return: None
        """
        self.model = SVC(kernel='linear') # 论文中使用的是linear SVMs
        self.model.fit(X, Y)
    def save(self, model_name):
        """
        :param model_name: 存储模型的文件名字
        :return:
        """
        joblib.dump(self.model, model_name)

    def load(self, model_name):
        """
        :param model_name:存储模型的文件名
        :return:
        """
        joblib.load(model_name)